import torch
import torch.nn as nn

# --- Helper MLP Modules for the Coupling Layers ---

class MLPTransform(nn.Module):
    """A standard MLP used for scale and shift transformations in the image flow.

    The network maps feature vectors of size ``dim // 2`` to ``dim``
    (Linear -> LayerNorm -> ReLU) and back to ``dim // 2`` (Linear). It is
    applied to every feature vector along the last axis independently.

    Args:
        dim: Full feature width of the coupling layer; this module consumes
            and produces one half of it (``dim // 2``).
    """
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim // 2, dim),
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim // 2)
        )

    def forward(self, x):
        """Apply the MLP to each feature vector of ``x``.

        Args:
            x: Tensor of shape ``(..., dim // 2)``; typically ``(B, N, D)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``dim // 2``.
        """
        leading_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted; it only copies when a view is impossible.
        flat = x.reshape(-1, x.shape[-1])   # [prod(leading), D] for the MLP
        out = self.net(flat)
        # Restore the leading dims explicitly instead of relying on -1.
        return out.reshape(*leading_shape, out.shape[-1])

class CausalMLPTransform(nn.Module):
    """A causal MLP for scale and shift transformations in the time-series flow.

    The network (Linear -> ReLU -> Linear) maps feature vectors of size
    ``dim // 2`` to ``dim`` and back to ``dim // 2``. Every position is
    transformed independently of all others, so no information flows between
    positions; causality along the sequence axis holds for that reason.

    Args:
        dim: Full feature width of the coupling layer; this module consumes
            and produces one half of it (``dim // 2``).
    """
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim // 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim // 2)
        )

    def forward(self, x):
        """Apply the MLP to each feature vector of ``x``.

        Args:
            x: Tensor of shape ``(..., dim // 2)``; typically ``(B, N, D)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``dim // 2``.
        """
        leading_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous inputs are accepted; it only
        # copies when a view is impossible.
        flat = x.reshape(-1, x.shape[-1])   # [prod(leading), D]
        out = self.net(flat)
        return out.reshape(*leading_shape, out.shape[-1])

# --- Core Coupling Layer ---

class RealNVPCouplingLayer(nn.Module):
    """
    Implements a single RealNVP coupling layer.
    It splits the input's last dimension into two halves, using the first half to
    compute scale (s) and shift (t) parameters that transform the second half.
    The first half passes through unchanged, which is what makes the layer
    invertible. Only the transformed tensor is returned; no log-determinant
    is computed here.

    Args:
        dim: Feature width of the input's last dimension. Must be a positive
            even integer so the two halves have equal size.
        causal: If True use CausalMLPTransform for the s/t networks,
            otherwise MLPTransform.
        scale_clamp: Bound applied to the log-scale: ``s`` is squashed to
            ``tanh(raw) * scale_clamp``. Defaults to 5.0.

    Raises:
        ValueError: If ``dim`` is smaller than 2 or odd.
    """
    # Class-level default so instances created without running __init__
    # (e.g. unpickled from an older version of this class) still have a bound.
    scale_clamp = 5.0

    def __init__(self, dim, causal=False, scale_clamp=5.0):
        super().__init__()
        # chunk(2) yields unequal halves for odd widths, which would only
        # surface later as an opaque shape mismatch inside the s/t networks.
        if dim < 2 or dim % 2 != 0:
            raise ValueError(
                f"dim must be a positive even integer, got {dim}"
            )
        Transform = CausalMLPTransform if causal else MLPTransform
        self.scale_net = Transform(dim)
        self.shift_net = Transform(dim)
        self.scale_clamp = scale_clamp

    def forward(self, x, reverse=False):
        """Apply the coupling transform or its inverse.

        Args:
            x: Tensor whose last dimension is split into two equal halves.
            reverse: If True apply the inverse mapping
                ``x2 = (x2 - t) * exp(-s)``; otherwise the forward mapping
                ``x2 = x2 * exp(s) + t``.

        Returns:
            Tensor of the same shape as ``x``: the untouched first half
            concatenated with the transformed second half.
        """
        # Split the feature dimension into two halves
        x1, x2 = x.chunk(2, dim=-1)

        # Both parameters depend only on x1, which is returned unchanged, so
        # the inverse pass can recompute exactly the same s and t.
        # tanh keeps the log-scale bounded so exp(s) cannot overflow.
        s = torch.tanh(self.scale_net(x1)) * self.scale_clamp
        t = self.shift_net(x1)

        if reverse:
            # Inverse transformation
            x2 = (x2 - t) * torch.exp(-s)
        else:
            # Forward transformation
            x2 = x2 * torch.exp(s) + t

        return torch.cat([x1, x2], dim=-1)

# --- Main Normalizing Flow Modules ---

class RealNVPImageFlow(nn.Module):
    """A stack of RealNVP coupling layers for PRPD images, with permutations.

    Each layer ``i`` is paired with a fixed permutation ``perm_{i}`` of the
    feature (last) axis, stored as a buffer. The forward pass permutes and
    then applies the layer; the reverse pass walks the layers backwards,
    inverts the layer and then undoes the permutation.

    Args:
        dim: Feature width of the last axis of the input.
        num_layers: Number of coupling layers (and permutations).
    """
    def __init__(self, dim=768, num_layers=6):
        super().__init__()
        # A fixed seed keeps the initial weights and permutations the same
        # across runs. Seeding inside fork_rng restores the caller's CPU RNG
        # state afterwards, so building this module no longer overrides a
        # seed chosen by the main script.
        # NOTE: only the CPU default generator is seeded and restored here;
        # modules created on another default device are not seeded by this.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(42)

            self.layers = nn.ModuleList([
                RealNVPCouplingLayer(dim, causal=False) for _ in range(num_layers)
            ])

            # Register permutations as non-trainable buffers
            for i in range(num_layers):
                self.register_buffer(f"perm_{i}", torch.randperm(dim))

    def forward(self, x, reverse=False):
        """Run the flow forwards, or backwards when ``reverse`` is true.

        Args:
            x: Input tensor; permutations are applied to its last axis.
            reverse: If true, apply the inverse of the forward mapping.

        Returns:
            Tensor of the same shape as ``x``.
        """
        if reverse:
            # Undo the layers last-to-first: invert the coupling, then the
            # permutation that preceded it in the forward pass.
            for i in range(len(self.layers) - 1, -1, -1):
                perm = getattr(self, f"perm_{i}")
                x = self.layers[i](x, reverse)
                # argsort of a permutation is its inverse. It is recomputed
                # from the buffer so it stays correct after load_state_dict.
                x = x[..., torch.argsort(perm)]
            return x

        for i, layer in enumerate(self.layers):
            perm = getattr(self, f"perm_{i}")
            x = layer(x[..., perm], reverse)  # permute features, then couple
        return x

class RealNVPSignalFlow(nn.Module):
    """A stack of causal RealNVP coupling layers for time-series data, with permutations.

    Each layer ``i`` is paired with a fixed permutation ``perm_{i}`` of the
    feature (last) axis, stored as a buffer. The forward pass permutes and
    then applies the layer; the reverse pass walks the layers backwards,
    inverts the layer and then undoes the permutation.

    Args:
        dim: Feature width of the last axis of the input.
        num_layers: Number of coupling layers (and permutations).
    """
    def __init__(self, dim=768, num_layers=8):
        super().__init__()
        # A fixed seed keeps the initial weights and permutations the same
        # across runs. Seeding inside fork_rng restores the caller's CPU RNG
        # state afterwards, so building this module no longer overrides a
        # seed chosen by the main script.
        # NOTE: only the CPU default generator is seeded and restored here;
        # modules created on another default device are not seeded by this.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(42)

            self.layers = nn.ModuleList([
                RealNVPCouplingLayer(dim, causal=True) for _ in range(num_layers)
            ])

            # Register permutations as non-trainable buffers
            for i in range(num_layers):
                self.register_buffer(f"perm_{i}", torch.randperm(dim))

    def forward(self, x, reverse=False):
        """Run the flow forwards, or backwards when ``reverse`` is true.

        Args:
            x: Input tensor; permutations are applied to its last axis.
            reverse: If true, apply the inverse of the forward mapping.

        Returns:
            Tensor of the same shape as ``x``.
        """
        if reverse:
            # Undo the layers last-to-first: invert the coupling, then the
            # permutation that preceded it in the forward pass.
            for i in range(len(self.layers) - 1, -1, -1):
                perm = getattr(self, f"perm_{i}")
                x = self.layers[i](x, reverse)
                # argsort of a permutation is its inverse. It is recomputed
                # from the buffer so it stays correct after load_state_dict.
                x = x[..., torch.argsort(perm)]
            return x

        for i, layer in enumerate(self.layers):
            perm = getattr(self, f"perm_{i}")
            x = layer(x[..., perm], reverse)  # permute features, then couple
        return x